"""
graph_utils.knn_backends
------------------------

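High-performance k-NN search with optional backends:
- FAISS (recommended for large datasets)
- hnswlib (fast and memory efficient)
- Annoy (tree-based approximate search)
- PyNNDescent (approximate nearest-neighbour descent)
- Scikit-learn (universal fallback; a naive NumPy search is used when it is
  not installed)

Each optional backend is auto-detected at import time.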
"""
from typing import Tuple, Optional 
import numpy as np
# Import scikit‑learn for kNN fallback.  When scikit‑learn is not available
# we fall back to a naive pairwise distance search.
try:
    from sklearn.neighbors import NearestNeighbors  # type: ignore[import]
    _HAS_SKLEARN = True
except Exception:
    NearestNeighbors = None  # type: ignore[assignment]
    _HAS_SKLEARN = False
import logging

# Optional backends (auto-detect)
try:
    import faiss  # type: ignore
    HAS_FAISS = True
except Exception:
    HAS_FAISS = False

try:
    import hnswlib  # type: ignore
    HAS_HNSWLIB = True
except Exception:
    HAS_HNSWLIB = False

try:
    from annoy import AnnoyIndex  # type: ignore
    HAS_ANNOY = True
except Exception:
    HAS_ANNOY = False

try:
    import pynndescent  # type: ignore
    HAS_PYNNDESCENT = True
except Exception:
    HAS_PYNNDESCENT = False

# Constants
# Upper bound on the neighbour count; fast_knn_search clamps the requested k
# into the range [1, KNN_MAX_K].
KNN_MAX_K = 200
# Public API: the dispatcher, the per-backend search functions, the backend
# probe, and the availability flags set by the optional imports above.
__all__ = [
    "fast_knn_search",
    "faiss_knn",
    "hnswlib_knn",
    "sklearn_knn_optimized",
    "annoy_knn",
    "pynndescent_knn",
    "available_knn_backends",
    "HAS_FAISS", "HAS_HNSWLIB", "HAS_ANNOY", "HAS_PYNNDESCENT",
]

# Utility for backend detection (for CLI/debug)
def available_knn_backends() -> dict:
    """Report which optional k-NN backends were importable.

    Returns:
        A dict mapping each optional backend name (``"faiss"``, ``"hnswlib"``,
        ``"annoy"``, ``"pynndescent"``) to the module-level ``HAS_*`` flag that
        was set when this module was imported.
    """
    backend_names = ("faiss", "hnswlib", "annoy", "pynndescent")
    backend_flags = (HAS_FAISS, HAS_HNSWLIB, HAS_ANNOY, HAS_PYNNDESCENT)
    return dict(zip(backend_names, backend_flags))

def hnswlib_knn(X: np.ndarray, k: int, *, space: str = "l2") -> Tuple[np.ndarray, np.ndarray]:
    """Run an all-points approximate k-NN query with an hnswlib HNSW index.

    The index is built from ``X`` and then queried with ``X`` itself.

    Args:
        X: 2-D array ``(n_samples, n_features)``; it is converted to float32.
        k: Number of neighbours to return for every row.
        space: hnswlib space name: ``"l2"``, ``"ip"`` or ``"cosine"``.  Common
            metric spellings are accepted as aliases (``"euclidean"`` and
            ``"sqeuclidean"`` map to ``"l2"``; ``"inner_product"`` and
            ``"dot"`` map to ``"ip"``) because hnswlib itself rejects them.

    Returns:
        ``(distances, indices)`` as float32 and int32 arrays.  Distances are
        reported exactly as hnswlib returns them for the chosen space.
        NOTE(review): hnswlib documents its ``"l2"`` distance as *squared* L2
        - confirm that callers expect that scale.
    """
    # Translate generic metric names to the three names hnswlib understands.
    space_aliases = {
        "euclidean": "l2",
        "sqeuclidean": "l2",
        "inner_product": "ip",
        "dot": "ip",
    }
    space = (space or "l2").lower()
    space = space_aliases.get(space, space)

    n, d = X.shape
    k = int(k)
    # Convert once; the same float32 view is used for both insert and query.
    data = X.astype(np.float32, copy=False)

    index = hnswlib.Index(space=space, dim=int(d))
    # Graph connectivity scales with dimensionality, bounded to [8, 32].
    M = min(32, max(8, d // 4))
    ef_construction = max(200, k * 4)
    # The query-time candidate list is kept at least twice as large as k.
    ef_search = max(k * 2, 50)
    index.init_index(max_elements=int(n), M=int(M), ef_construction=int(ef_construction))
    index.set_ef(int(ef_search))

    ids = np.arange(n, dtype=np.int32)
    index.add_items(data, ids)
    indices, distances = index.knn_query(data, k=k)
    return distances.astype(np.float32, copy=False), indices.astype(np.int32, copy=False)


# --- FAISS backend (supports the "euclidean" and "cosine" metrics) ---
def faiss_knn(
    X: np.ndarray,
    k: int,
    metric: str = "euclidean",
    metric_params: dict | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Run an all-points k-NN query with FAISS.

    The index is built from ``X`` and then queried with ``X`` itself.

    Args:
        X: 2-D array ``(n_samples, n_features)``; converted to C-contiguous
            float32 before it is handed to FAISS.
        k: Number of neighbours to return for every row.
        metric: ``"cosine"`` uses an exact inner-product index over
            L2-normalised rows and reports ``1 - similarity``.  Every other
            value uses the L2 path; a warning is logged when the name is not
            an L2 spelling, because the result is then not in that metric.
        metric_params: Accepted for signature parity with the other backends;
            currently unused.

    Returns:
        ``(distances, indices)`` as float32 and int32 arrays.  On the L2 path
        the distances are the *squared* L2 values FAISS returns.
    """
    # FAISS works on C-contiguous float32 buffers; this copies only if needed.
    X32 = np.ascontiguousarray(X, dtype=np.float32)
    n, d = X32.shape
    k = int(k)
    metric = (metric or "euclidean").lower()

    if metric == "cosine":
        # Normalise rows so the inner product equals cosine similarity; the
        # epsilon keeps all-zero rows from dividing by zero.
        norms = np.linalg.norm(X32, axis=1, keepdims=True) + 1e-12
        Xn = np.ascontiguousarray(X32 / norms, dtype=np.float32)
        index = faiss.IndexFlatIP(d)
        index.add(Xn)
        sims, indices = index.search(Xn, k)
        # Rounding can push a similarity slightly above 1; clamp so that a
        # distance is never negative.
        distances = np.maximum(1.0 - sims, 0.0)
        return distances.astype(np.float32, copy=False), indices.astype(np.int32, copy=False)

    if metric not in ("euclidean", "l2", "sqeuclidean"):
        # Previously this fell through silently; keep the fallback but say so.
        logging.getLogger(__name__).warning(
            "faiss_knn: metric %r is not supported; using squared L2 instead.", metric
        )

    # L2 path: pick the index type from the data size.
    if n < 10_000 or d < 32:
        # Small or low-dimensional data: exact brute-force search.
        index = faiss.IndexFlatL2(d)
    elif n < 100_000:
        # Medium data: inverted-file index with about sqrt(n) cells.
        nlist = min(int(np.sqrt(n)), 1024)
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
        index.train(X32)
        # Probe a quarter of the cells, capped at 32.
        index.nprobe = min(max(1, nlist // 4), 32)
    else:
        # Large data: HNSW graph index.
        index = faiss.IndexHNSWFlat(d, 32)
        index.hnsw.efConstruction = 200
        index.hnsw.efSearch = max(k * 2, 50)

    index.add(X32)
    distances, indices = index.search(X32, k)  # FAISS returns squared L2 already
    return distances.astype(np.float32, copy=False), indices.astype(np.int32, copy=False)


def annoy_knn(X: np.ndarray, k: int, metric: str = "euclidean", n_trees: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """Run an all-points approximate k-NN query with an Annoy forest.

    Args:
        X: 2-D array ``(n_samples, n_features)``; converted to float32.
        k: Number of neighbours to return for every row.
        metric: Annoy metric name.  ``"cosine"`` is translated to Annoy's
            ``"angular"`` metric; distances are returned as Annoy provides
            them and are not converted to ``1 - cosine similarity``.
        n_trees: Number of trees to build.

    Returns:
        ``(distances, indices)`` as float32 ``(n, k)`` and int32 ``(n, k)``
        arrays.  When Annoy returns fewer than ``k`` neighbours for a row, the
        remaining slots hold index ``-1`` and distance ``inf`` instead of the
        row assignment failing with a shape error.
    """
    metric = (metric or "euclidean").lower()
    if metric == "cosine":
        metric = "angular"

    n, d = X.shape
    k = int(k)
    # Convert the whole matrix once rather than once per row.
    X32 = X.astype(np.float32, copy=False)

    index = AnnoyIndex(int(d), metric)
    for i, row in enumerate(X32):
        index.add_item(i, row)
    index.build(int(n_trees))

    # Pre-fill with sentinels so short result lists leave well-defined padding.
    indices = np.full((n, k), -1, dtype=np.int32)
    distances = np.full((n, k), np.inf, dtype=np.float32)
    for i in range(n):
        idxs, dists = index.get_nns_by_item(i, k, include_distances=True)
        found = len(idxs)
        indices[i, :found] = idxs
        distances[i, :found] = dists  # angular ~ cosine-like; leave as provided
    return distances, indices

def pynndescent_knn(
    X: np.ndarray,
    k: int,
    metric: str = "euclidean",
    n_jobs: int = -1,
    random_state: int = 42,
    metric_params: dict | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build a PyNNDescent index over ``X`` and return its neighbour graph.

    Args:
        X: Data matrix handed to ``pynndescent.NNDescent`` unchanged.
        k: Number of neighbours per point (``n_neighbors``).
        metric: Metric name forwarded to PyNNDescent.
        n_jobs: Parallelism setting forwarded to PyNNDescent.
        random_state: Seed forwarded to PyNNDescent.
        metric_params: Extra metric arguments (``metric_kwds``); an empty or
            missing dict is forwarded as ``None``.

    Returns:
        ``(distances, indices)`` as float32 and int32 arrays taken from the
        index's ``neighbor_graph``.
    """
    nnd_options = {
        "n_neighbors": int(k),
        "metric": metric,
        "n_jobs": n_jobs,
        "random_state": random_state,
        # Falsy values (None, {}) collapse to None, as PyNNDescent expects.
        "metric_kwds": metric_params if metric_params else None,
    }
    nnd_index = pynndescent.NNDescent(X, **nnd_options)
    neighbour_ids, neighbour_dists = nnd_index.neighbor_graph
    return (
        neighbour_dists.astype(np.float32, copy=False),
        neighbour_ids.astype(np.int32, copy=False),
    )



def sklearn_knn_optimized(X: np.ndarray, k: int, *, metric: str = "euclidean") -> Tuple[np.ndarray, np.ndarray]:
    """Compute k-nearest neighbours with scikit-learn, or naively without it.

    Args:
        X: 2-D array ``(n_samples, n_features)``; every row is used both as an
            indexed point and as a query.
        k: Number of neighbours to return for every row.
        metric: Metric name.  With scikit-learn it is forwarded unchanged.  The
            naive fallback implements ``"euclidean"``/``"l2"`` and
            ``"cosine"``; any other name logs a warning and uses Euclidean.

    Returns:
        ``(distances, indices)`` as float32 and int32 arrays.
    """
    n, d = X.shape
    k = int(k)

    if _HAS_SKLEARN and NearestNeighbors is not None:
        # Metric names accepted by both KDTree and BallTree.  Forcing a tree
        # for any other metric (e.g. "cosine") makes scikit-learn raise, so
        # those metrics let scikit-learn choose the algorithm itself.
        tree_metrics = (
            "euclidean", "l2", "minkowski", "manhattan",
            "cityblock", "l1", "chebyshev", "infinity",
        )
        if n < 1000:
            algorithm, leaf_size = "brute", 30
        elif metric not in tree_metrics:
            algorithm, leaf_size = "auto", 30
        elif d > 20:
            algorithm, leaf_size = "ball_tree", max(10, n // 100)
        else:
            algorithm, leaf_size = "kd_tree", max(15, n // 50)
        nn = NearestNeighbors(
            n_neighbors=k,
            metric=metric,
            algorithm=algorithm,
            leaf_size=int(leaf_size),
            n_jobs=-1,
        ).fit(X)
        distances, indices = nn.kneighbors(X)
        return distances.astype(np.float32, copy=False), indices.astype(np.int32, copy=False)

    # Naive fallback: build the full (n, n) distance matrix in float32 to
    # limit memory usage, then keep the k smallest entries of every row.
    X32 = X.astype(np.float32, copy=False)
    n = X32.shape[0]
    metric_name = str(metric or "euclidean").lower()

    if metric_name == "cosine":
        # Cosine distance = 1 - dot product of L2-normalised rows; the epsilon
        # protects all-zero rows, the clamp removes rounding negatives.
        unit = X32 / (np.linalg.norm(X32, axis=1, keepdims=True) + 1e-12)
        dists = np.maximum(1.0 - unit @ unit.T, 0.0).astype(np.float32, copy=False)
    else:
        if metric_name not in ("euclidean", "l2"):
            logging.getLogger(__name__).warning(
                "Naive kNN fallback does not support metric %r; using euclidean.", metric
            )
        dists = np.empty((n, n), dtype=np.float32)
        # One row at a time keeps the temporary difference array at (n, d).
        for i in range(n):
            diff = X32[i] - X32
            d2 = np.einsum('ij,ij->i', diff, diff, dtype=np.float32)
            dists[i] = np.sqrt(d2)

    # Stable sort so that ties (e.g. duplicate points) resolve by index.
    idxs = np.argsort(dists, axis=1, kind="stable")[:, :k]
    rows = np.arange(n)[:, None]
    dist_k = dists[rows, idxs]
    return dist_k, idxs.astype(np.int32)

# Dispatcher: annoy_knn and pynndescent_knn already accept `metric`, so it is forwarded to them unchanged.

def fast_knn_search(
    X: np.ndarray,
    k: int,
    backend: str = "auto",
    *,
    metric: Optional[str] = None,
    **kwargs
) -> Tuple[np.ndarray, np.ndarray]:
    """Dispatch a k-NN query to the requested backend.

    Args:
        X: Data matrix passed to the selected backend.
        k: Requested neighbour count; clamped into ``[1, KNN_MAX_K]``.
        backend: ``"faiss"``, ``"hnswlib"``, ``"annoy"``, ``"pynndescent"``,
            ``"sklearn"`` or ``"auto"`` (case-insensitive; ``None`` is treated
            as ``"auto"``).  ``"auto"``, ``"sklearn"``, and any backend that
            is unknown or not installed use ``sklearn_knn_optimized``; the
            last case also logs a warning.
        metric: Metric name; ``None`` selects each backend's default.  It is
            ignored by the FAISS branch.  For hnswlib, ``"euclidean"`` and
            ``"sqeuclidean"`` are translated to the ``"l2"`` space.
        **kwargs: Backend options: ``n_trees`` (Annoy); ``n_jobs``,
            ``random_state`` and ``metric_params`` (PyNNDescent).

    Returns:
        ``(distances, indices)`` from the backend.  Distance scales are
        backend-specific (the FAISS L2 path returns squared distances).
    """
    k = max(1, min(int(k), KNN_MAX_K))
    backend = (backend or "auto").lower()

    if backend == "faiss" and HAS_FAISS:
        # faiss uses L2/IP; metric is ignored here (handled by metrics.pre/post if needed)
        return faiss_knn(X, k)

    if backend == "hnswlib" and HAS_HNSWLIB:
        # hnswlib only knows "l2", "ip" and "cosine"; translate the generic
        # Euclidean spellings so they do not fail at index construction.
        space = (metric or "l2").lower()
        space = {"euclidean": "l2", "sqeuclidean": "l2"}.get(space, space)
        return hnswlib_knn(X, k, space=space)

    if backend == "annoy" and HAS_ANNOY:
        return annoy_knn(X, k, metric=(metric or "euclidean"), n_trees=kwargs.get("n_trees", 10))

    if backend == "pynndescent" and HAS_PYNNDESCENT:
        return pynndescent_knn(
            X, k, metric=(metric or "euclidean"),
            n_jobs=kwargs.get("n_jobs", -1),
            random_state=kwargs.get("random_state", 42),
            metric_params=kwargs.get("metric_params"),
        )

    if backend not in ("auto", "sklearn"):
        logging.warning("Falling back to sklearn kNN (backend=%r unavailable).", backend)
    return sklearn_knn_optimized(X, k, metric=(metric or "euclidean"))